/******************************************************************************
This file is part of AppKit.
Project: appkit
Author : FergusZeng
Email  : cblock@126.com
git	   : https://gitee.com/newgolo/appkit.git
*******************************************************************************
MIT License

Copyright (c) 2022 cblock@126.com

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
******************************************************************************/
#include "appkit/thread.h"

#include <fcntl.h>
#include <limits.h>
#include <signal.h>
#include <stdlib.h>
#include <sys/ipc.h>
#include <sys/mman.h>
#include <sys/prctl.h>
#include <sys/sem.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>

#include <cerrno>
#include <chrono>
#include <cstring>

#include "appkit/tracer.h"

namespace appkit {
/* Runnable: a unit of work executed by a Thread. m_thread points at the thread
 * currently executing run(), or is null when no thread is attached. */
Runnable::Runnable() = default;

Runnable::~Runnable() { m_thread = nullptr; }

/* True when no thread is attached, or the attached thread reports it is no
 * longer running. */
bool Runnable::isExited() {
    return (m_thread == nullptr) || !m_thread->isRunning();
}

Thread::Thread() {}
Thread::Thread(const std::string& name) : m_name(name) {}
/* Copy constructor, used by clone(). Only the configuration (the thread name) is
 * copied; the runtime members (future, running flag, runnable) start out
 * default-initialized, so the copy is not tied to the source's worker. */
Thread::Thread(const Thread& copy) : m_name(copy.m_name) {}

/* Stops the worker (waiting for it to finish) before the object goes away. */
Thread::~Thread() { stop(); }

/* Returns a heap-allocated Thread copy-constructed from *this. */
std::unique_ptr<Thread> Thread::clone() const {
    auto duplicate = std::make_unique<Thread>(*this);
    return duplicate;
}

bool Thread::start(const Runnable& runnable) {
    if (m_running.load()) {
        TRACE_WARN_CLASS("[%s](0x%x) thread is allready started!",
                         m_name.data(), threadID());
        return true;
    }
    m_runnable = const_cast<Runnable*>(&runnable);
    try {
        m_future = std::async(std::launch::async, [&]() {
            if (!m_name.empty()) {
                Thread::setName(m_name);
            }
            TRACE_INFO_CLASS("[%s](0x%x) thread enter.", m_name.data(),
                             threadID());
            m_running.store(true);
            m_runnable->m_thread = this;
            m_runnable->run(*this);
            m_runnable->m_thread = nullptr;
            m_running.store(false);
            TRACE_INFO_CLASS("[%s](0x%x) thread exit.", m_name.data(),
                             threadID());
        });
    } catch (const std::system_error& e) { /* 当系统资源不够时,抛出系统异常 */
        TRACE_ERR_CLASS("[%s](0x%x) thread start error: %s", m_name.data(),
                        threadID(), e.what());
        return false;
    }
    return true;
}

/**
 * Asks the worker to finish (clears the running flag) and blocks until it has
 * returned.
 * NOTE(review): msTimeout is currently ignored: this waits without a limit and
 * always returns true.
 */
bool Thread::stop(int msTimeout) {
    (void)msTimeout;
    m_running.store(false);
    if (!m_future.valid()) {
        return true;
    }
    m_future.wait();
    return true;
}

bool Thread::isRunning() const { return m_running.load(); }

void Thread::usleep(int us) {
    us = (us <= 0) ? 1 : us;
    std::this_thread::sleep_for(std::chrono::microseconds(us));
}

void Thread::msleep(int ms) {
    ms = (ms <= 0) ? 1 : ms;
    std::this_thread::sleep_for(std::chrono::milliseconds(ms));
}

/**
 * Returns an integer tag for the calling thread, intended for log output.
 * The value is the first sizeof(int) bytes of the std::thread::id object (the
 * same bytes the previous cast read). Where the id is wider than int, it is
 * truncated and therefore not guaranteed to be unique.
 */
int Thread::threadID() {
    const std::thread::id tid = std::this_thread::get_id();
    static_assert(sizeof(tid) >= sizeof(int),
                  "std::thread::id is smaller than int");
    /* memcpy instead of reinterpret_cast<int*>: reading the id through an int*
     * violates strict aliasing (undefined behaviour). */
    int id = 0;
    std::memcpy(&id, &tid, sizeof(id));
    return id;
}

/**
 * Pins the calling thread to a single CPU.
 * @param cpuId zero-based CPU index; must lie in [0, CPU_SETSIZE).
 * @return true on success; false if cpuId is out of range or
 *         pthread_setaffinity_np() rejects the mask.
 */
bool Thread::setAffinity(int cpuId) {
    /* CPU_SET() performs no bounds checking: an out-of-range index would write
     * outside the cpu_set_t. */
    if (cpuId < 0 || cpuId >= CPU_SETSIZE) {
        return false;
    }
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(cpuId, &mask);
    return 0 == pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
}

/* Sets the calling thread's name via prctl(PR_SET_NAME); an empty name leaves
 * the current name unchanged. */
void Thread::setName(const std::string& name) {
    if (name.empty()) {
        return;
    }
    prctl(PR_SET_NAME, name.c_str());
}

PThread::PThread() {}

/**
 * @param policy    scheduling policy; with SCHED_POLICY_OTHER the priority is forced to 0
 * @param priority  scheduling priority
 * @param inherit   true: PTHREAD_INHERIT_SCHED (use the creator's scheduling);
 *                  false: PTHREAD_EXPLICIT_SCHED (use policy/priority)
 * @param stackSize stack size in bytes; values below PTHREAD_STACK_MIN are raised to it
 */
PThread::PThread(int policy, int priority, bool inherit, int stackSize) {
    m_threadAttribute.m_policy = policy;
    m_threadAttribute.m_priority =
        (policy == SCHED_POLICY_OTHER) ? 0 : priority;
    m_threadAttribute.m_inherit =
        inherit ? PTHREAD_INHERIT_SCHED : PTHREAD_EXPLICIT_SCHED;
    /* Stack size is assigned exactly once, clamped to the system minimum. */
    if (stackSize < PTHREAD_STACK_MIN) {
        m_threadAttribute.m_stackSize = PTHREAD_STACK_MIN;
    } else {
        m_threadAttribute.m_stackSize = stackSize;
    }
}

/* Stops the thread with stop()'s default timeout. */
PThread::~PThread() { stop(); }

/* Returns a heap-allocated PThread copy-constructed from *this. */
std::unique_ptr<Thread> PThread::clone() const {
    return std::make_unique<PThread>(*this);
}

bool PThread::start(const Runnable& runnable) {
    pthread_attr_t* pAttr = nullptr;
    if (m_status != STATE_INIT) {
        TRACE_ERR_CLASS("thread state error, state: %d", m_status);
        return false;
    }

    m_runnable = const_cast<Runnable*>(&runnable);
    /* 设置线程属性 */
    pthread_attr_init(&m_attribute);
    if (setAttribute(m_attribute)) {
        pAttr = &m_attribute;
    }
    m_status = STATE_START;
    if (0 != pthread_create(&m_threadID, pAttr, startRoutine, this)) {
        TRACE_ERR_CLASS("pthread create error: %s!", ERRSTR);
        return false;
    }
    pthread_attr_destroy(&m_attribute);
    return true;
}

/**
 * Requests the worker to leave run() (clears m_runFlag) and polls every 10 ms
 * until threadMain() reports STATE_EXIT.
 * @param msTimeout maximum wait in ms; negative means MAX_SINT32 ms, and values
 *                  below 10 are raised to 10 (the polling step).
 * @return true if the flag was already clear or the thread exited in time,
 *         false on timeout.
 * NOTE(review): threadMain() clears m_runFlag before it publishes STATE_EXIT, so
 * a stop() landing in between returns true while the thread is still finishing.
 */
bool PThread::stop(int msTimeout) {
    if (!m_runFlag) {
        return true;
    }
    m_runFlag = false;
    if (msTimeout < 0) {
        msTimeout = MAX_SINT32;
    }
    msTimeout = MAX(10, msTimeout);
    while (msTimeout > 0) {
        if (m_status == STATE_EXIT) {
            /* No pthread_join(): startRoutine() detached the thread, so joining it
             * (or the id 0 that threadMain() stores on exit) is undefined
             * behaviour. The system reclaims a detached thread by itself. */
            m_threadID = 0;
            return true;
        }
        msleep(10);
        msTimeout -= 10;
    }
    return false;
}

/* True only while a runnable is attached, no stop has been requested and the
 * worker is in STATE_RUNNING. */
bool PThread::isRunning() const {
    const bool requested = (m_runnable != nullptr) && m_runFlag;
    return requested && (m_status == STATE_RUNNING);
}

/**
 * Sleeps for at least `us` microseconds on CLOCK_MONOTONIC, which is not affected
 * by changes to the system clock. Non-positive values sleep for 1 microsecond.
 */
void PThread::usleep(int us) {
    if (us <= 0) {
        us = 1;
    }
    struct timespec remaining;
    remaining.tv_sec = us / 1000000;
    remaining.tv_nsec = static_cast<long>(us % 1000000) * 1000L;
    /* clock_nanosleep() returns EINTR (and stores the unslept time) when a signal
     * handler interrupts it; continue with the remainder so the full duration
     * elapses. */
    while (clock_nanosleep(CLOCK_MONOTONIC, 0, &remaining, &remaining) == EINTR) {
    }
}

/**
 * Sleeps for at least `ms` milliseconds; non-positive values sleep for 1 microsecond.
 * Long durations are split into chunks so the conversion to microseconds never
 * overflows int.
 */
void PThread::msleep(int ms) {
    constexpr int kMaxChunkMs = INT_MAX / 1000;
    if (ms <= 0) {
        PThread::usleep(1);
        return;
    }
    while (ms > kMaxChunkMs) {
        PThread::usleep(kMaxChunkMs * 1000);
        ms -= kMaxChunkMs;
    }
    PThread::usleep(ms * 1000);
}

/**
 * Applies the configured stack size, scheduling policy, priority and inherit
 * mode to the attribute object passed in (start() passes m_attribute).
 * @return true if every setting was applied; false at the first failure
 *         (settings applied before it stay in the object).
 */
bool PThread::setAttribute(const pthread_attr_t& threadAttr) {
    /* The settings must land in the caller's object, because start() hands that
     * same object to pthread_create(). Writing them into a by-value copy (as
     * before) discarded all of them; pthread_attr_t is also opaque and not meant
     * to be copied. The parameter stays const only to keep the declared
     * interface. NOTE(review): this assumes callers pass a non-const object, as
     * start() does with m_attribute. */
    pthread_attr_t* attribute = const_cast<pthread_attr_t*>(&threadAttr);
    if (pthread_attr_setstacksize(attribute, m_threadAttribute.m_stackSize) !=
        0) {
        TRACE_ERR_CLASS("Set thread attribute stack size:%d failed!",
                        m_threadAttribute.m_stackSize);
        return false;
    }

    /* scheduling policy */
    if (pthread_attr_setschedpolicy(attribute, m_threadAttribute.m_policy) !=
        0) {
        TRACE_ERR_CLASS("Set thread attribute policy:%d failed",
                        m_threadAttribute.m_policy);
        return false;
    }

    /* scheduling priority (set after the policy it applies to) */
    struct sched_param param;
    param.sched_priority = m_threadAttribute.m_priority;
    if (pthread_attr_setschedparam(attribute, &param) != 0) {
        TRACE_ERR_CLASS("Set thread attribute priority:%d failed",
                        m_threadAttribute.m_priority);
        return false;
    }

    /* whether to inherit the creator's scheduling or use the values above */
    if (pthread_attr_setinheritsched(attribute, m_threadAttribute.m_inherit) !=
        0) {
        TRACE_ERR_CLASS("Set thread sched inherit:%d failed",
                        m_threadAttribute.m_inherit);
        return false;
    }
    return true;
}

#if 0
/* Disabled cancellation-based stop, kept for reference only (not compiled).
 * NOTE(review): startRoutine() detaches every PThread, so the pthread_join() below
 * would be invalid if this were re-enabled. */
bool PThread::cancel() {
    if (m_threadID > 0) {
        /* pthread_cancel() does not cancel the thread immediately; cancellation only takes effect at the next system call (cancellation point) or pthread_testcancel()
         */
        if (0 != pthread_cancel(m_threadID)) {
            TRACE_ERR_CLASS("pthread cancel error:%s.", ERRSTR);
            return false;
        }
        pthread_join(m_threadID, nullptr);
    }
    m_status = STATE_EXIT;
    m_threadID = 0;
    return true;
}

/* Explicit cancellation point for the running thread. */
void PThread::setCancelPoint() {
    pthread_testcancel();
}
#endif

/* Thread body: runs the attached runnable on the new thread and publishes the
 * state transitions STATE_RUNNING -> STATE_EXIT. */
void PThread::threadMain() {
    m_status = STATE_RUNNING;
    m_threadID = pthread_self();
    if (m_runnable) {
        m_runFlag = true;
        m_runnable->run(*this);
        m_runFlag = false;
    }
    /* Clear the id BEFORE publishing STATE_EXIT: stop() returns as soon as it
     * sees STATE_EXIT, after which the owner may destroy this object, so no
     * member may be written after that store. */
    m_threadID = 0;
    m_status = STATE_EXIT;
    pthread_exit(nullptr);
}

/* Entry point handed to pthread_create(); arg is the owning PThread. */
void* PThread::startRoutine(void* arg) {
    /* Detach so the thread's resources are reclaimed automatically when it ends;
     * pthread_join() is then no longer needed. */
    pthread_detach(pthread_self());
    auto* self = static_cast<PThread*>(arg);
    if (self != nullptr) {
        self->threadMain();
    }
    return nullptr;
}

/**
 * @struct ThreadArgs
 * @brief Arguments for a thread started by Threading::startPThreading().
 *        Ownership passes to the new thread; Threading::startRoutine() frees them.
 */
struct ThreadArgs {
    Threading* m_owner{nullptr};        /* Threading instance that started the thread */
    std::function<void(void*)> m_entry; /* task run on the new thread */
    void* m_args{nullptr};              /* argument passed to m_entry */
};

Threading::Threading() = default;

/* Blocks until every task recorded in m_futures has finished. Threads started by
 * startPThreading() are detached and are not recorded there. */
Threading::~Threading() {
    for (auto it = m_futures.begin(); it != m_futures.end(); ++it) {
        if (it->valid()) {
            it->wait();
        }
    }
}

/**
 * Runs func(nullptr) on a new detached POSIX thread.
 * @param func task to run; must not be empty.
 * @return true if the thread was created. The thread is not added to m_futures,
 *         so the destructor does not wait for it.
 */
bool Threading::startPThreading(std::function<void(void*)> func) {
    if (!func) {
        TRACE_ERR_CLASS("thread function cannot be nullptr!");
        return false;
    }
    auto threadArgs = std::make_unique<ThreadArgs>();
    threadArgs->m_owner = this;
    threadArgs->m_entry = std::move(func);
    threadArgs->m_args = nullptr;
    pthread_t threadID;
    const int rc =
        pthread_create(&threadID, nullptr, startRoutine, threadArgs.get());
    if (rc != 0) {
        /* threadArgs still owns the arguments here, so they are freed rather than
         * leaked. pthread_create() reports errors via its return value, not errno. */
        TRACE_ERR_CLASS("threading failed:%s!", std::strerror(rc));
        return false;
    }
    /* The new thread now owns the arguments (startRoutine() deletes them). */
    (void)threadArgs.release();
    return true;
}

/* Entry point for threads started by startPThreading(). Takes ownership of the
 * ThreadArgs allocated there and frees them when the task returns. */
void* Threading::startRoutine(void* arg) {
    /* After detaching, pthread_join() is not needed: resources are reclaimed
     * automatically. */
    pthread_detach(pthread_self());
    std::unique_ptr<ThreadArgs> args(static_cast<ThreadArgs*>(arg));
    const bool runnable =
        (args->m_owner != nullptr) && static_cast<bool>(args->m_entry);
    if (runnable) {
        args->m_entry(args->m_args);
    }
    return nullptr;
}

/*
 * RWMutex: readers share the lock, writers get it exclusively.
 * m_rMutex protects m_readCnts. m_wMutex is held either by one writer or,
 * collectively, by the readers: it is taken by the first reader in lockR() and
 * released by the last reader in unLockR(), so a steady stream of readers can
 * keep writers waiting.
 * NOTE(review): the last reader can be a different thread from the first, so
 * m_wMutex may be unlocked by a thread that did not lock it. That is undefined
 * behaviour for std::mutex; confirm m_wMutex's type.
 */
void RWMutex::lockR() {
    m_rMutex.lock();
    const bool firstReader = (++m_readCnts == 1);
    if (firstReader) {
        m_wMutex.lock(); /* first reader locks out writers (only once) */
    }
    m_rMutex.unlock();
}

void RWMutex::unLockR() {
    m_rMutex.lock();
    const bool lastReader = (--m_readCnts == 0);
    if (lastReader) {
        m_wMutex.unlock(); /* last reader lets writers in again */
    }
    m_rMutex.unlock();
}

/* Exclusive (writer) lock. */
void RWMutex::lockW() {
    m_wMutex.lock();
}

void RWMutex::unLockW() {
    m_wMutex.unlock();
}

Semaphore::Semaphore() = default;
/* Does not close the semaphore; callers must call close() themselves. */
Semaphore::~Semaphore() = default;
/**
 * Opens the semaphore.
 * @param name  empty: initialize an unnamed semaphore in m_sem (pshared = 0);
 *              otherwise a named POSIX semaphore called "/<name>".
 * @param value initial value. Must be >= 0 for an unnamed semaphore. For a named
 *              one, value < 0 opens an existing semaphore without changing its
 *              value, and value >= 0 unlinks an existing one and recreates it
 *              with this value.
 * @return true on success.
 * NOTE(review): for named semaphores `m_sem = *sem;` copies the sem_t out of the
 * mapping returned by sem_open(). The copy is private to this process, so it is
 * no longer shared with other openers, and the mapping itself is never
 * sem_close()d. Fixing this needs m_sem to become a sem_t* (header change).
 */
bool Semaphore::open(const std::string& name, int value) {
    m_name = name;
    if (m_name.empty()) {
        if (value < 0) {
            return false;
        }
        /* initialize an unnamed semaphore */
        if (-1 == sem_init(&m_sem, 0, value)) {
            TRACE_ERR_CLASS("sem init error:%s", ERRSTR);
            return false;
        }
    } else {
        /* named semaphore: the name must start with '/', see sem_overview(7) */
        m_name = std::string("/") + m_name;
        sem_t* sem =
            sem_open(CSTR(m_name), O_CREAT | O_RDWR | O_EXCL, 0664, value);
        if (sem == SEM_FAILED) { /* exclusive create failed, e.g. it already exists */
            if (value < 0) {     /* only open it, keep its current value */
                sem = sem_open(CSTR(m_name), 0);
                if (sem == SEM_FAILED) {
                    TRACE_ERR_CLASS("sem_open error: %s", ERRSTR);
                    return false;
                }
            } else { /* unlink, then recreate with the requested value */
                unlink();
                sem = sem_open(CSTR(m_name), O_CREAT | O_RDWR | O_EXCL, 0664,
                               value);
                if (sem == SEM_FAILED) {
                    TRACE_ERR_CLASS("sem open with value(%d) error: %s", value,
                                    ERRSTR);
                    return false;
                }
            }
        }
        m_sem = *sem;
    }
    return true;
}

/* Releases the semaphore: sem_close() for a named one, sem_destroy() for an
 * unnamed one.
 * NOTE(review): for named semaphores m_sem is a copy made in open(), not the
 * pointer sem_open() returned, so sem_close(&m_sem) is applied to the copy. */
bool Semaphore::close() {
    if (!m_name.empty()) {
        if (sem_close(&m_sem) == 0) {
            return true;
        }
        TRACE_ERR_CLASS("sem close error:%s", ERRSTR);
        return false;
    }
    if (sem_destroy(&m_sem) == 0) {
        return true;
    }
    TRACE_ERR_CLASS("sem destroy error:%s", ERRSTR);
    return false;
}

/* Removes a named semaphore's name from the system; nothing to do for an
 * unnamed one. */
bool Semaphore::unlink() {
    if (m_name.empty()) {
        return true;
    }
    if (sem_unlink(CSTR(m_name)) == 0) {
        return true;
    }
    TRACE_ERR_CLASS("sem unlink(%s) error:%s", CSTR(m_name), ERRSTR);
    return false;
}

/**
 * Decrements the semaphore, blocking while its value is 0.
 * @param msTimeout maximum wait in ms; negative waits without limit.
 *                  NOTE(review): assumes the header's default is negative (the
 *                  convention PThread::stop() uses); confirm.
 * @return true when decremented, false on timeout or error.
 */
bool Semaphore::wait(int msTimeout) {
    int rc = 0;
    if (msTimeout < 0) {
        /* Restart when a signal handler interrupts the wait. */
        do {
            rc = sem_wait(&m_sem);
        } while (rc != 0 && errno == EINTR);
        return rc == 0;
    }
    /* sem_timedwait() expects an absolute CLOCK_REALTIME deadline. */
    struct timespec deadline;
    clock_gettime(CLOCK_REALTIME, &deadline);
    deadline.tv_sec += msTimeout / 1000;
    deadline.tv_nsec += static_cast<long>(msTimeout % 1000) * 1000000L;
    if (deadline.tv_nsec >= 1000000000L) {
        deadline.tv_sec += 1;
        deadline.tv_nsec -= 1000000000L;
    }
    do {
        rc = sem_timedwait(&m_sem, &deadline);
    } while (rc != 0 && errno == EINTR);
    return rc == 0;
}

/* Non-blocking decrement: true if the semaphore could be decremented now. */
bool Semaphore::tryWait() {
    return sem_trywait(&m_sem) == 0;
}

/* Increments the semaphore. */
bool Semaphore::post() {
    return sem_post(&m_sem) == 0;
}

/* Stores the current value in *value; false if sem_getvalue() fails. */
bool Semaphore::getValue(int* value) {
    return sem_getvalue(&m_sem, value) != -1;
}

/**
 * @union semun
 * @brief Fourth argument of System V semctl(); applications must declare this
 *        union themselves (see semctl(2)).
 * NOTE(review): `__buf` is a reserved identifier in C++ (double underscore); it
 * only mirrors the man-page declaration.
 */
union semun {
    int val;              /* Value for SETVAL */
    struct semid_ds* buf; /* Buffer for IPC_STAT, IPC_SET */
    unsigned short* array; /* Array for GETALL, SETALL */ /* NOLINT */
    struct seminfo* __buf; /* Buffer for IPC_INFO (Linux-specific) */
};
SemaphoreV::SemaphoreV() = default;
/* Does not remove the semaphore; callers must call close() themselves. */
SemaphoreV::~SemaphoreV() = default;
bool SemaphoreV::open(const std::string& name, int value) {
    if (m_semid >= 0) {
        TRACE_ERR_CLASS("sem[%s] is allready opened!", CSTR(m_name));
        return false;
    }
    auto key = ftok(CSTR(name), 1);
    if (key < 0) {
        TRACE_ERR_CLASS("sem[%s] get key error: %s!", CSTR(name), ERRSTR);
        return false;
    }
    int semid = semget((key_t)key, 1, IPC_CREAT | IPC_EXCL | 0666);
    if (semid < 0) {
        // IPC_CREAT | IPC_EXCL | S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP
        semid = semget((key_t)key, 1, IPC_CREAT | 0666);
        if (semid < 0) {
            TRACE_ERR_CLASS("sem[%s] open error: %s", CSTR(name), ERRSTR);
            return false;
        }
        if (value >= 0) { /* 重新打开并初始化值 */
            union semun sem_union;
            if (semctl(semid, 0, IPC_RMID, sem_union) < 0) {
                TRACE_ERR_CLASS("sem[%s] rmid error: %s!", CSTR(name), ERRSTR);
                return false;
            }
            semid = semget((key_t)key, 1, IPC_CREAT | 0666);
        }
    }
    if (semid < 0) {
        TRACE_ERR_CLASS("sem[%s] open error: %s", CSTR(name), ERRSTR);
        return false;
    }
    if (value >= 0) {
        union semun sem_union;
        sem_union.val = value;
        if (semctl(semid, 0, SETVAL, sem_union) < 0) {
            TRACE_ERR_CLASS("sem[%s] setval(%d) error: %s", CSTR(name), value,
                            ERRSTR);
            this->close();
            return false;
        }
    }
    m_name = name;
    m_semid = semid;
    return true;
}

/**
 * Removes the semaphore set from the system (IPC_RMID), which affects every
 * process using it, and marks this object closed. Closing an unopened object
 * succeeds.
 */
bool SemaphoreV::close() {
    if (m_semid < 0) {
        return true;
    }
    /* IPC_RMID takes no semun argument, so no (uninitialized) union is passed. */
    if (semctl(m_semid, 0, IPC_RMID) < 0) {
        TRACE_ERR_CLASS("sem[%s] close error: %s!", CSTR(m_name), ERRSTR);
        return false;
    }
    m_semid = -1;
    return true;
}

/**
 * Decrements the semaphore, blocking while its value is 0. SEM_UNDO makes the
 * kernel revert the adjustment when the process exits.
 * @param msTimeout maximum wait in ms; negative waits without limit.
 *                  NOTE(review): assumes the header's default is negative; confirm.
 * @return true when decremented (also when not open, kept for compatibility);
 *         false on timeout or error.
 */
bool SemaphoreV::wait(int msTimeout) {
    if (m_semid < 0) {
        return true;
    }
    struct sembuf semb;
    semb.sem_num = 0;
    semb.sem_op = -1;
    semb.sem_flg = SEM_UNDO;
    int rc = 0;
    if (msTimeout < 0) {
        /* Restart when a signal handler interrupts the wait. */
        do {
            rc = semop(m_semid, &semb, 1);
        } while (rc < 0 && errno == EINTR);
    } else {
        /* semtimedop() takes a relative timeout. */
        struct timespec timeout;
        timeout.tv_sec = msTimeout / 1000;
        timeout.tv_nsec = static_cast<long>(msTimeout % 1000) * 1000000L;
        rc = semtimedop(m_semid, &semb, 1, &timeout);
        if (rc < 0 && errno == EAGAIN) {
            return false; /* timed out; expected, not logged as an error */
        }
    }
    if (rc < 0) {
        TRACE_ERR_CLASS("sem[%s] wait error: %s!", CSTR(m_name), ERRSTR);
        return false;
    }
    return true;
}

/* Increments the semaphore by one; SEM_UNDO makes the kernel revert this when the
 * process exits. Returns true without doing anything when not open. */
bool SemaphoreV::post() {
    if (m_semid < 0) {
        return true;
    }
    struct sembuf increment;
    increment.sem_num = 0;
    increment.sem_op = 1;
    increment.sem_flg = SEM_UNDO;
    const bool ok = (semop(m_semid, &increment, 1) >= 0);
    if (!ok) {
        TRACE_ERR_CLASS("sem[%s] post error: %s!", CSTR(m_name), ERRSTR);
    }
    return ok;
}

/**
 * Reads the current semaphore value into *value.
 * @return false if value is null, the semaphore is not open, or
 *         semctl(GETVAL) fails.
 */
bool SemaphoreV::getValue(int* value) {
    if (value == nullptr) {
        return false;
    }
    if (m_semid < 0) {
        TRACE_ERR_CLASS("sem[%s] not open!", CSTR(m_name));
        return false;
    }
    /* GETVAL delivers the value as semctl()'s return value; it does not write into
     * a semun argument (the old code read an uninitialized union). */
    const int semValue = semctl(m_semid, 0, GETVAL);
    if (semValue < 0) {
        TRACE_ERR_CLASS("sem[%s] getval error: %s!", CSTR(m_name), ERRSTR);
        return false;
    }
    *value = semValue;
    return true;
}

EventQueue::EventQueue() = default;
EventQueue::~EventQueue() = default;

/* Discards every queued event. */
void EventQueue::clearEvent() {
    std::unique_lock<std::mutex> guard(m_cvMutex);
    m_eventQueue.clear();
}

/**
 * Queues `event` and wakes all waiters.
 * @return false (and drops the event) if it is negative or no waitEvent() call
 *         is in progress.
 */
bool EventQueue::notifyEvent(const int& event) {
    if (event < 0) {
        TRACE_ERR_CLASS("event cannot less than 0!");
        return false;
    }
    if (!m_isWaiting.load()) {
        TRACE_ERR_CLASS("no waiting started!");
        return false;
    }
    std::unique_lock<std::mutex> guard(m_cvMutex);
    m_eventQueue.emplace_back(event);
    guard.unlock(); /* notify without holding the lock */
    m_cond.notify_all();
    return true;
}

int EventQueue::waitEvent(const std::set<int>& eventSet, uint32 msTimeout) {
    uint64 usTimeout = msTimeout * 1000;
    Time startTime = Time::fromMono();
    m_isWaiting.store(true);
    for (;;) {
        auto interval = Time::usSinceMono(startTime);
        if (interval >= usTimeout) {
            m_isWaiting.store(false);
            return RC_TIMEOUT;
        }
        {
            int ret = -1;
            std::unique_lock<std::mutex> lock(m_cvMutex);
            if (m_cond.wait_for(lock, std::chrono::milliseconds(10), [&] {
                    for (auto& event : m_eventQueue) {
                        if (eventSet.find(event) != eventSet.end()) {
                            ret = event;
                            break;
                        }
                    }
                    return (ret >= 0) ? true : false;
                })) {
                if (ret >= 0) {
                    m_isWaiting.store(false);
                    return ret;
                }
            }
        }
    }
}

}  // namespace appkit
